import tensorflow.compat.v1 as tf
tf.compat.v1.disable_eager_execution()
import datetime
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel
from flask import Flask, jsonify, request, render_template

# Path to the local GPT-2 checkpoint directory.
# NOTE(review): this points at the 355M checkpoint, but the page labels below
# say "117 million" — confirm which model is actually deployed.
model_dir = "C:/01.project/chatgpt/gpt-2/models/355M"

# Load the GPT-2 tokenizer from the local checkpoint directory.
tokenizer = GPT2Tokenizer.from_pretrained(model_dir)

# Load the GPT-2 model (TensorFlow head with LM output) from the same directory.
model = TFGPT2LMHeadModel.from_pretrained(model_dir)


# Create the Flask application instance.
app = Flask(__name__)

# JSON API endpoint: POST /generate
@app.route('/generate', methods=['POST'])
def generate_text():
    """Generate a GPT-2 continuation of the posted prompt.

    Expects a JSON body of the form ``{"text": "..."}`` and returns
    ``{"generated_text": "..."}``.

    Raises:
        KeyError / 400-level errors if the JSON body lacks a "text" key.
    """
    input_text = request.get_json()['text']
    input_ids = tokenizer.encode(input_text, return_tensors='tf')
    # BUG FIX: the original call was model.generate(max_length=..., ...) with
    # no input_ids, so the encoded prompt was silently discarded and the model
    # generated from an empty context. Pass the prompt tensor explicitly.
    output_ids = model.generate(input_ids, max_length=100, num_return_sequences=1)
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return jsonify({'generated_text': output_text})

@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the demo page.

    GET renders the empty form; POST reads ``{"text": "..."}`` from the JSON
    body, generates a GPT-2 continuation, and renders it together with the
    elapsed request time.
    """
    start_time = datetime.datetime.now()
    if request.method == 'POST':
        input_text = request.get_json()['text']
        input_ids = tokenizer.encode(input_text, return_tensors='tf')
        # BUG FIX: the original call omitted input_ids, so the user's prompt
        # was never fed to the model. Pass the encoded prompt explicitly.
        output_ids = model.generate(input_ids, max_length=100, num_return_sequences=1)
        output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        # NOTE(review): the labels below say "117 million" while model_dir
        # points at the 355M checkpoint — confirm which model is deployed.
        return render_template('index.html', generated_text="(请求耗时: {}秒)".format(datetime.datetime.now() - start_time) + output_text, title_lable="基于 GPT-2 的开源模型搭建(117 million)", top_lable="GPT-2开源模型(117 million)")
    else:
        return render_template('index.html', title_lable="基于 GPT-2 的开源模型搭建(117 million)", top_lable="GPT-2开源模型(117 million)")



# Start the application (Flask development server, all interfaces, port 5000).
if __name__ == '__main__':
    app.run(host="0.0.0.0", port=5000)
